{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Guidance acceleration\n",
    "\n",
    "When multiple generation or LLM-directed control flow statements are used in a single Guidance program then we can significantly improve inference performance by maintaining a session state with the LLM inferencer and so reusing the Key/Value caches as we progress through the prompt. This is much faster than letting the model generate all of the structural tokens itself (for example if the structure was demonstrated using a one-shot example), and also faster that simply recalling the model without any state at each point in the Guidance program.\n",
    "\n",
    "We call this \"guidance acceleration\" and an early implementation of this is available in `guidance.llms.Transformers` as demonstrated below. Note we also show the performance impact of token healing (which has not yet been highly optimized)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# prefix tokens: 36\n",
      "With guidance accelaration and token healing: 1.780165433883667\n",
      "                  With guidance accelaration: 1.4087750911712646\n",
      "               Without guidance accelaration: 2.530573606491089\n",
      "       Single generation call of same length: 13.744869470596313\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "import torch\n",
    "import guidance\n",
    "\n",
    "# Define a trivial string we can extend with small models easily\n",
    "reps = 20\n",
    "prefix = \"Repeat this. Repeat this. \"*5 + \"Repeat this. Repeat this.\"\n",
    "llm = guidance.llms.Transformers('gpt2', device='cpu')\n",
    "print(\"# prefix tokens:\", len(llm.encode(prefix)))\n",
    "del llm.model_obj\n",
    "\n",
    "model = 'gpt2-large'\n",
    "device = 'cuda'\n",
    "llm = guidance.llms.Transformers(model, device=device)\n",
    "guidance.llms.Transformers.cache.clear()\n",
    "start = time.time()\n",
    "template = \"\"\"{{prefix}}{{gen 'story' stop=\".\" max_tokens=50}}.\"\"\"*reps\n",
    "program = guidance(template, llm=llm)\n",
    "a = program(prefix=prefix)\n",
    "end = time.time()\n",
    "print(f\"With guidance accelaration and token healing:\", end - start)\n",
    "del llm.model_obj\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "llm = guidance.llms.Transformers(model, device=device, token_healing=False)\n",
    "guidance.llms.Transformers.cache.clear()\n",
    "start = time.time()\n",
    "template = \"\"\"{{prefix}}{{gen 'story' stop=\".\" max_tokens=50}}.\"\"\"*reps\n",
    "program = guidance(template, llm=llm)\n",
    "b = program(prefix=prefix)\n",
    "end = time.time()\n",
    "print(f\"                  With guidance accelaration:\", end - start)\n",
    "del llm.model_obj\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "llm = guidance.llms.Transformers(model, device=device, acceleration=False, token_healing=False)\n",
    "guidance.llms.Transformers.cache.clear()\n",
    "start = time.time()\n",
    "template = \"\"\"{{prefix}}{{gen 'story' stop=\".\" max_tokens=50}}.\"\"\"*reps\n",
    "program = guidance(template, llm=llm)\n",
    "b = program(prefix=prefix)\n",
    "end = time.time()\n",
    "print(f\"               Without guidance accelaration:\", end - start)\n",
    "del llm.model_obj\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "llm = guidance.llms.Transformers(model, device=device)\n",
    "guidance.llms.Transformers.cache.clear()\n",
    "start = time.time()\n",
    "template = \"\"\"{{prefix}}{{gen 'story' max_tokens=max_tokens}}\"\"\"\n",
    "program = guidance(template, llm=llm)\n",
    "b = program(prefix=prefix, max_tokens=len(llm.encode(str(a))) - len(llm.encode(\"Repeat this. Repeat this.\")))\n",
    "end = time.time()\n",
    "print(f\"       Single generation call of same length:\", end - start)\n",
    "del llm.model_obj\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<hr style=\"height: 1px; opacity: 0.5; border: none; background: #cccccc;\">\n",
    "<div style=\"text-align: center; opacity: 0.5\">Have an idea for more helpful examples? Pull requests that add to this documentation notebook are encouraged!</div>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "adatest",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
